
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from copy import deepcopy
import random
import numpy as np
import shelve
from constant import *  # provides attack_acts and down_attack, used in Brain.decide()


ENV='Tennis'

GAMMA=0.99          # discount factor
MAX_STEPS=2000      # max steps per episode
NUM_EPISODES=100000
CAPACITY=200000     # replay-buffer capacity
UPDATE_STEP=10      # target-network sync interval (see Brain.update)
SAVE_STEP=40        # checkpoint interval (see Brain.save)
Seq_max=5           # observations stacked per state
SKIP=1
LR=0.0001           # Adam learning rate
p=(0.1,0.2)         # epsilon bounds: anneals linearly from p[1] down to p[0]



class Net(nn.Module):
    """Dueling Q-network: a shared trunk feeding separate value and advantage heads."""
    def __init__(self,input_dim,output_dim):
        super(Net,self).__init__()
        # shared feature trunk
        self.m=nn.Sequential(
            nn.Linear(input_dim,128),
            nn.ReLU(inplace=True),
            nn.Linear(128,100),
            nn.ReLU(inplace=True)
        )
        # state-value head V(s)
        self.v=nn.Sequential(
            nn.Linear(100,1)
        )
        # advantage head A(s,a)
        self.a=nn.Sequential(
            nn.Linear(100,output_dim)
        )

    def forward(self,x):
        x=self.m(x)
        v=self.v(x)
        a=self.a(x)
        # dueling aggregation: Q(s,a) = V(s) + A(s,a) - mean_a' A(s,a')
        # (broadcasting makes the explicit expand unnecessary)
        return v+a-a.mean(dim=1,keepdim=True)
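
# A small sanity check for the dueling head above (illustrative only; not part
# of training). Because the advantages are mean-centered, the action-mean of
# the Q-values recovers V(s).
def _demo_dueling_head():
    net=Net(10,4)
    x=torch.rand(2,10)
    q=net(x)                      # (2, 4) Q-values
    v=net.v(net.m(x))             # (2, 1) state values
    assert torch.allclose(q.mean(dim=1,keepdim=True),v,atol=1e-6)
    print('dueling aggregation check ok:',q.shape)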

        


class Memory:
    """Replay buffer with a separate buffer for 'successful' transitions."""
    def __init__(self,reload_mem=False):
        self.cap=CAPACITY
        self.memory=[]        # ordinary transitions
        self.s_memory=[]      # successful transitions, kept separately so they can be over-sampled
        self.index=0
        self.s_index=0
        self.num=0
        self.save_num=100000  # persist buffers to disk every save_num pushes
        if reload_mem:
            self.load_mem()
    def Transition(self,state,action,state_next,reward):
        data={}
        data['state']=state
        data['action']=action
        data['state_next']=state_next
        data['reward']=reward
        return data

    def push_s(self,state,action,state_next,reward):
        # ring-buffer insert into the success buffer
        if len(self.s_memory)<self.cap:
            self.s_memory.append(self.Transition(state,action,state_next,reward))
        else:
            self.s_memory[self.s_index%self.cap]=self.Transition(state,action,state_next,reward)
        self.s_index=(self.s_index+1)%self.cap

        self.num+=1
        if self.num%self.save_num==0:
            self.num=0
            self.save_mem()

    def push(self,state,action,state_next,reward):
        # ring-buffer insert into the ordinary buffer
        if len(self.memory)<self.cap:
            self.memory.append(self.Transition(state,action,state_next,reward))
        else:
            self.memory[self.index%self.cap]=self.Transition(state,action,state_next,reward)
        self.index=(self.index+1)%self.cap

        self.num+=1
        if self.num%self.save_num==0:
            self.num=0
            self.save_mem()
    
    def sample(self,batch_size):
        # draw up to half the batch from the success buffer, the rest from the
        # ordinary buffer
        s_mem_n=min(len(self.s_memory),batch_size)
        if s_mem_n==batch_size:
            s_mem_n//=2
        sam=random.sample(self.s_memory,s_mem_n)
        sam+=random.sample(self.memory,batch_size-s_mem_n)

        # concatenate the sampled transitions into batched tensors, skipping
        # terminal transitions (state_next is None)
        batch={}
        first=True
        for i in sam:
            if i['state_next'] is None:
                continue
            if first:
                batch['state']=i['state']
                batch['action']=i['action']
                batch['state_next']=i['state_next']
                batch['reward']=i['reward']
                first=False
            else:
                batch['state']=torch.cat((batch['state'],i['state']))
                batch['action']=torch.cat((batch['action'],i['action']))
                batch['state_next']=torch.cat((batch['state_next'],i['state_next']))
                batch['reward']=torch.cat((batch['reward'],i['reward']))

        return batch

    def __len__(self):
        return len(self.memory)+len(self.s_memory)

    def is_experienced(self,batch_size=128):
        # enough ordinary transitions collected to form a batch?
        return len(self.memory)>batch_size

    def save_mem(self):
        # persist both buffers with shelve
        db=shelve.open(ENV+'.EXP')
        db['normal']=self.memory
        db['success']=self.s_memory
        db.close()
    
    def load_mem(self):
        db=shelve.open(ENV+'.EXP')
        self.memory=db['normal']
        self.s_memory=db['success']
        db.close()
        self.index=len(self.memory)
        self.s_index=len(self.s_memory)
        print('loaded %d success and %d normal transitions'%(len(self.s_memory),len(self.memory)))
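
# A minimal, illustrative exercise of Memory (hypothetical shapes, not part of
# training): transitions are stored as batch-of-1 tensors so that sample() can
# torch.cat them along dim 0.
def _demo_memory():
    mem=Memory()
    for _ in range(16):
        mem.push(
            torch.rand(1,10),                          # state
            torch.LongTensor([[random.randrange(4)]]), # action
            torch.rand(1,10),                          # next state (None if terminal)
            torch.FloatTensor([1.0])                   # reward
        )
    batch=mem.sample(8)
    print(batch['state'].shape,batch['action'].shape)  # torch.Size([8, 10]) torch.Size([8, 1])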



class Brain:
    def __init__(self,batchsize=3,num_states=10,num_actions=10,reload=None,load_mem=False):

        self.batch_size=batchsize
        self.num_states=num_states*Seq_max  # stacked-sequence state size (unused below; Net takes num_states directly)
        self.num_actions=num_actions

        self.Q=Net(num_states,num_actions)

        if reload is not None:
            self.load(reload)

        # target network, periodically synced from Q via update()
        self.TD=deepcopy(self.Q)

        self.optim=optim.Adam(self.Q.parameters(),lr=LR)

        self.mem=Memory(reload_mem=load_mem)
        print(self.Q)
        self.device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.Q=self.Q.to(self.device)
        self.TD=self.TD.to(self.device)

    def replay(self):
        if not self.mem.is_experienced(batch_size=self.batch_size):
            return

        batch=self.mem.sample(self.batch_size)
        if 'state' not in batch:  # every sampled transition was terminal
            return

        state_batch=batch['state'].to(self.device)
        action_batch=batch['action'].to(self.device)
        reward_batch=batch['reward'].to(self.device)
        non_final_next_state_batch=batch['state_next'].to(self.device)

        # Q(s,a) for the actions actually taken
        self.Q.eval()
        qs=self.Q(state_batch).gather(1,action_batch)

        # TD target from the frozen target network (no gradient through TD)
        self.TD.eval()
        with torch.no_grad():
            next_vals=self.TD(non_final_next_state_batch).max(1)[0]
        td=(reward_batch+GAMMA*next_vals).view(-1,1)

        self.Q.train()
        loss=F.smooth_l1_loss(qs,td)
        self.optim.zero_grad()
        loss.backward()
        self.optim.step()

    def decide(self,state,episode):
        # epsilon anneals linearly from p[1] down to p[0] over NUM_EPISODES
        epsilon=p[0]+(p[1]-p[0])*(NUM_EPISODES-episode)/NUM_EPISODES
        if np.random.uniform(0,1)<epsilon:
            # explore: random action, remapping attack actions (from constant.py)
            a=random.randrange(self.num_actions)
            if a in attack_acts:
                a=random.choice(down_attack)
            action=torch.LongTensor([[a]])
        else:
            # exploit: greedy action from the online network
            self.Q.eval()
            state=state.to(self.device)
            action=self.Q(state).max(1)[1].view(1,-1)

        return action

    def update(self):
        # sync the target network with the online network
        self.TD.load_state_dict(self.Q.state_dict())

    def save(self,path=None):
        torch.save(self.Q.state_dict(),ENV+'.pth' if path is None else path)

    def load(self,path=None):
        if path is None:
            path=ENV+'.pth'
        # map the checkpoint to CPU when no GPU is available
        map_loc=None if torch.cuda.is_available() else 'cpu'
        self.Q.load_state_dict(torch.load(path,map_location=map_loc))
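
# A minimal training-loop sketch showing how Brain is meant to be driven.
# Illustrative only: `env` is an assumed Gym-style environment (old 4-tuple
# step API), and the state/reward tensor shaping here is hypothetical.
def _training_loop_sketch(env,num_states,num_actions):
    brain=Brain(batchsize=32,num_states=num_states,num_actions=num_actions)
    for episode in range(NUM_EPISODES):
        obs=env.reset()
        state=torch.FloatTensor([obs])     # batch-of-1, as push/sample expect
        for step in range(MAX_STEPS):
            action=brain.decide(state,episode)
            obs,reward,done,_=env.step(action.item())
            state_next=None if done else torch.FloatTensor([obs])
            # keep stored tensors on CPU so sample() can concatenate them
            brain.mem.push(state,action.cpu(),state_next,torch.FloatTensor([reward]))
            brain.replay()
            if done:
                break
            state=state_next
        if episode%UPDATE_STEP==0:
            brain.update()                 # sync target network
        if episode%SAVE_STEP==0:
            brain.save()                   # checkpoint to ENV+'.pth'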











if __name__=='__main__':
    # quick smoke test of the dueling network on random input
    net=Net(10,10)
    t=torch.rand(1,10)
    output=net(t)
    print(output)


